#! /usr/bin/env python
# -*- coding: utf-8 -*-

"""A decoder that performs beam search.
This code is taken directly from Tensorflow
(https://github.com/tensorflow/tensorflow) and is copied temporarily
until it is available in a packaged Tensorflow version on pypi.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import tensorflow as tf

# gather_tree is implemented in beam_search_ops in the upstream TensorFlow
# sources; the import is restored here so that `finalize` below can call it.
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest


class BeamSearchDecoderState(
    collections.namedtuple("BeamSearchDecoderState", ("cell_state", "log_probs",
                                                      "finished", "lengths"))):
    pass


class BeamSearchDecoderOutput(
    collections.namedtuple("BeamSearchDecoderOutput",
                           ("scores", "predicted_ids", "parent_ids"))):
    pass
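

# FinalBeamSearchDecoderOutput is returned by `finalize` below but was missing
# from this copy; restored here from the upstream TensorFlow sources.
class FinalBeamSearchDecoderOutput(
    collections.namedtuple("FinalBeamSearchDecoderOutput",
                           ("predicted_ids", "beam_search_decoder_output"))):
    pass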


def _tile_batch(t, multiplier):
    """Core single-tensor implementation of tile_batch."""
    t = ops.convert_to_tensor(t, name="t")
    shape_t = tf.shape(t)
    if t.shape.ndims is None or t.shape.ndims < 1:
        raise ValueError("t must have statically known rank")
    tiling = [1] * (t.shape.ndims + 1)
    tiling[1] = multiplier
    tiled_static_batch_size = (
        t.shape[0].value * multiplier if t.shape[0].value is not None else None)
    tiled = tf.tile(tf.expand_dims(t, 1), tiling)
    tiled = tf.reshape(
        tiled, tf.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
    tiled.set_shape(
        tensor_shape.TensorShape(
            [tiled_static_batch_size]).concatenate(t.shape[1:]))
    return tiled


def tile_batch(t, multiplier, name=None):
    """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
    For each tensor t in a (possibly nested structure) of tensors,
    this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
    minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
    `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
    `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
    `multiplier` times.
    Args:
      t: `Tensor` shaped `[batch_size, ...]`.
      multiplier: Python int.
      name: Name scope for any created operations.
    Returns:
      A (possibly nested structure of) `Tensor` shaped
      `[batch_size * multiplier, ...]`.
    Raises:
      ValueError: if tensor(s) `t` do not have a statically known rank or
      the rank is < 1.
    """
    flat_t = nest.flatten(t)
    with tf.name_scope(name, "tile_batch", flat_t + [multiplier]):
        return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
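

# Usage sketch (hypothetical names: `encoder_outputs` and `beam_width` are
# defined by the caller, not by this module). Encoder tensors must be tiled
# so that each beam sees its own copy of every batch entry, e.g.
#   tiled_memory = tile_batch(encoder_outputs, multiplier=beam_width)
# turns shape [batch_size, max_time, num_units] into
# [batch_size * beam_width, max_time, num_units].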


def _check_maybe(t):
    if isinstance(t, tensor_array_ops.TensorArray):
        raise TypeError(
            "TensorArray state is not supported by BeamSearchDecoder: %s" % t.name)
    if t.shape.ndims is None:
        raise ValueError(
            "Expected tensor (%s) to have known rank, but ndims == None." % t)


class BeamSearchDecoder(decoder.Decoder):
    """BeamSearch sampling decoder."""

    def __init__(self,
                 cell,
                 embedding,
                 start_tokens,
                 end_token,
                 initial_state,
                 beam_width,
                 output_layer=None,
                 length_penalty_weight=0.0):
        """Initialize BeamSearchDecoder.
        Args:
          cell: An `RNNCell` instance.
          embedding: A callable that takes a vector tensor of `ids` (argmax ids),
            or the `params` argument for `embedding_lookup`.
          start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
          end_token: `int32` scalar, the token that marks end of decoding.
          initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
          beam_width: Python integer, the number of beams.
          output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
            `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
            to storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
        Raises:
          TypeError: if `cell` is not an instance of `RNNCell`,
            or `output_layer` is not an instance of `tf.layers.Layer`.
          ValueError: If `start_tokens` is not a vector or
            `end_token` is not a scalar.
        """
        if not rnn_cell_impl._like_rnncell(cell):
            raise TypeError(
                "cell must be an RNNCell, received: %s" % type(cell))
        if (output_layer is not None
                and not isinstance(output_layer, layers_base.Layer)):
            raise TypeError(
                "output_layer must be a Layer, received: %s" % type(output_layer))
        self._cell = cell
        self._output_layer = output_layer

        if callable(embedding):
            self._embedding_fn = embedding
        else:
            self._embedding_fn = (
                lambda ids: embedding_ops.embedding_lookup(embedding, ids))

        self._start_tokens = ops.convert_to_tensor(
            start_tokens, dtype=dtypes.int32, name="start_tokens")
        if self._start_tokens.get_shape().ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._end_token = ops.convert_to_tensor(
            end_token, dtype=dtypes.int32, name="end_token")
        if self._end_token.get_shape().ndims != 0:
            raise ValueError("end_token must be a scalar")

        self._batch_size = tf.size(start_tokens)
        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams,
            initial_state, self._cell.state_size)
        self._start_tokens = tf.tile(
            tf.expand_dims(self._start_tokens, 1), [1, self._beam_width])
        self._start_inputs = self._embedding_fn(self._start_tokens)
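
    # The `decoder.Decoder` interface requires a `batch_size` property; it was
    # missing from this copy and is restored here from the upstream sources.
    @property
    def batch_size(self):
        return self._batch_size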

    def _rnn_output_size(self):
        size = self._cell.output_size
        if self._output_layer is None:
            return size
        else:
            # To use layer's compute_output_shape, we need to convert the
            # RNNCell's output_size entries into shapes with an unknown
            # batch size.  We then pass this through the layer's
            # compute_output_shape and read off all but the first (batch)
            # dimensions to get the output size of the rnn with the layer
            # applied to the top.
            output_shape_with_unknown_batch = nest.map_structure(
                lambda s: tensor_shape.TensorShape([None]).concatenate(s),
                size)
            layer_output_shape = self._output_layer._compute_output_shape(  # pylint: disable=protected-access
                output_shape_with_unknown_batch)
            return nest.map_structure(lambda s: s[1:], layer_output_shape)

    @property
    def output_size(self):
        # Return the cell output and the id
        return BeamSearchDecoderOutput(
            scores=tensor_shape.TensorShape([self._beam_width]),
            predicted_ids=tensor_shape.TensorShape([self._beam_width]),
            parent_ids=tensor_shape.TensorShape([self._beam_width]))

    @property
    def output_dtype(self):
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and int32 (the id)
        dtype = nest.flatten(self._initial_cell_state)[0].dtype
        return BeamSearchDecoderOutput(
            scores=nest.map_structure(
                lambda _: dtype, self._rnn_output_size()),
            predicted_ids=dtypes.int32,
            parent_ids=dtypes.int32)

    def initialize(self, name=None):
        """Initialize the decoder.
        Args:
          name: Name scope for any created operations.
        Returns:
          `(finished, start_inputs, initial_state)`.
        """
        # All beams start out unfinished, with zero log prob and zero length.
        finished = tf.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.bool)
        start_inputs = self._start_inputs

        initial_state = BeamSearchDecoderState(
            cell_state=self._initial_cell_state,
            log_probs=tf.zeros(
                [self._batch_size, self._beam_width],
                dtype=nest.flatten(self._initial_cell_state)[0].dtype),
            finished=finished,
            lengths=tf.zeros(
                [self._batch_size, self._beam_width], dtype=dtypes.int32))

        return (finished, start_inputs, initial_state)

    def finalize(self, outputs, final_state, sequence_lengths):
        """Finalize and return the predicted_ids.
        Args:
          final_state: An instance of BeamSearchDecoderState. Passed through to the
            output.
          sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
            The sequence lengths determined for each beam during decode.
        Returns:
          outputs: An instance of FinalBeamSearchDecoderOutput where the
            predicted_ids are the result of calling _gather_tree.
          final_state: The same input instance of BeamSearchDecoderState.
        """
        predicted_ids = gather_tree(
            outputs.predicted_ids, outputs.parent_ids,
            sequence_length=sequence_lengths)

        outputs = FinalBeamSearchDecoderOutput(
            beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
        return outputs, final_state

    def _merge_batch_beams(self, t, s=None):
        """Merges the tensor from a batch of beams into a batch by beams.
        More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We
        reshape this into [batch_size*beam_width, s]
        Args:
          t: Tensor of dimension [batch_size, beam_width, s]
          s: (Possibly known) depth shape.
        Returns:
          A reshaped version of t with dimension [batch_size * beam_width, s].
        """
        if isinstance(s, ops.Tensor):
            s = tensor_shape.as_shape(tensor_util.constant_value(s))
        else:
            s = tensor_shape.TensorShape(s)
        t_shape = tf.shape(t)
        static_batch_size = tensor_util.constant_value(self._batch_size)
        batch_size_beam_width = (
            None if static_batch_size is None
            else static_batch_size * self._beam_width)
        reshaped_t = tf.reshape(
            t, tf.concat(
                ([self._batch_size * self._beam_width], t_shape[2:]), 0))
        reshaped_t.set_shape(
            (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s)))
        return reshaped_t

    def _split_batch_beams(self, t, s=None):
        """Splits the tensor from a batch by beams into a batch of beams.
        More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
        reshape this into [batch_size, beam_width, s]
        Args:
          t: Tensor of dimension [batch_size*beam_width, s].
          s: (Possibly known) depth shape.
        Returns:
          A reshaped version of t with dimension [batch_size, beam_width, s].
        Raises:
          ValueError: If, after reshaping, the new tensor is not shaped
            `[batch_size, beam_width, s]` (assuming batch_size and beam_width
            are known statically).
        """
        if isinstance(s, ops.Tensor):
            s = tensor_shape.TensorShape(tensor_util.constant_value(s))
        else:
            s = tensor_shape.TensorShape(s)
        t_shape = tf.shape(t)
        reshaped_t = tf.reshape(
            t, tf.concat(
                ([self._batch_size, self._beam_width], t_shape[1:]), 0))
        static_batch_size = tensor_util.constant_value(self._batch_size)
        expected_reshaped_shape = tensor_shape.TensorShape(
            [static_batch_size, self._beam_width]).concatenate(s)
        if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape):
            raise ValueError("Unexpected behavior when reshaping between beam width "
                             "and batch size.  The reshaped tensor has shape: %s.  "
                             "We expected it to have shape "
                             "(batch_size, beam_width, depth) == %s.  Perhaps you "
                             "forgot to create a zero_state with "
                             "batch_size=encoder_batch_size * beam_width?"
                             % (reshaped_t.shape, expected_reshaped_shape))
        reshaped_t.set_shape(expected_reshaped_shape)
        return reshaped_t
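
    # Shape sketch (with batch_size=2, beam_width=3, depth d):
    #   _merge_batch_beams: [2, 3, d] -> [6, d]
    #   _split_batch_beams: [6, d]    -> [2, 3, d]
    # The two are inverses: merging lets the wrapped RNNCell treat every beam
    # as an independent batch entry, and splitting restores the beam axis.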

    def _maybe_split_batch_beams(self, t, s):
        """Maybe splits the tensor from a batch by beams into a batch of beams.
        We do this so that we can use nest and not run into problems with shapes.
        Args:
          t: Tensor of dimension [batch_size*beam_width, s]
          s: Tensor, Python int, or TensorShape.
        Returns:
          Either a reshaped version of t with dimension
          [batch_size, beam_width, s] if t's first dimension is of size
          batch_size*beam_width or t if not.
        Raises:
          TypeError: If t is an instance of TensorArray.
          ValueError: If the rank of t is not statically known.
        """
        _check_maybe(t)
        if t.shape.ndims >= 1:
            return self._split_batch_beams(t, s)
        else:
            return t

    def _maybe_merge_batch_beams(self, t, s):
        """Splits the tensor from a batch by beams into a batch of beams.
        More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
        reshape this into [batch_size, beam_width, s]
        Args:
          t: Tensor of dimension [batch_size*beam_width, s]
          s: Tensor, Python int, or TensorShape.
        Returns:
          A reshaped version of t with dimension [batch_size, beam_width, s].
        Raises:
          TypeError: If t is an instance of TensorArray.
          ValueError:  If the rank of t is not statically known.
        """
        _check_maybe(t)
        if t.shape.ndims >= 2:
            return self._merge_batch_beams(t, s)
        else:
            return t

    def step(self, time, inputs, state, name=None):
        """Perform a decoding step.
        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          name: Name scope for any created operations.
        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight

        with tf.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
            cell_state = state.cell_state
            inputs = nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
            cell_state = nest.map_structure(
                self._maybe_merge_batch_beams,
                cell_state, self._cell.state_size)
            cell_outputs, next_cell_state = self._cell(inputs, cell_state)
            cell_outputs = nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
            next_cell_state = nest.map_structure(
                self._maybe_split_batch_beams,
                next_cell_state, self._cell.state_size)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight)

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
            next_inputs = tf.cond(
                tf.reduce_all(finished), lambda: self._start_inputs,
                lambda: self._embedding_fn(sample_ids))

        return (beam_search_output, beam_search_state, next_inputs, finished)
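
    # A minimal usage sketch (hypothetical names: `encoder_state`,
    # `embedding_matrix`, `vocab_size`, `start_token_id`, `end_token_id` and
    # `num_units` are not defined in this module):
    #
    #   cell = tf.contrib.rnn.LSTMCell(num_units)
    #   bsd = BeamSearchDecoder(
    #       cell=cell,
    #       embedding=embedding_matrix,
    #       start_tokens=tf.fill([batch_size], start_token_id),
    #       end_token=end_token_id,
    #       initial_state=tile_batch(encoder_state, multiplier=beam_width),
    #       beam_width=beam_width,
    #       output_layer=tf.layers.Dense(vocab_size))
    #   outputs, final_state, lengths = decoder.dynamic_decode(bsd)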


def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight):
    """Performs a single step of Beam Search Decoding.
    Args:
      time: Beam search time step, should start at 0. At time 0 we assume
        that all beams are equal and consider only the first beam for
        continuations.
      logits: Logits at the current time step. A tensor of shape
        `[batch_size, beam_width, vocab_size]`
      next_cell_state: The next state from the cell, e.g. an instance of
        AttentionWrapperState if the cell is attentional.
      beam_state: Current state of the beam search.
        An instance of `BeamSearchDecoderState`.
      batch_size: The batch size for this input.
      beam_width: Python int.  The size of the beams.
      end_token: The int32 end token.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    Returns:
      A tuple `(output, next_state)`, where `output` is a
      `BeamSearchDecoderOutput` holding the scores, predicted ids and parent
      beam ids for this step, and `next_state` is the new
      `BeamSearchDecoderState`.
    """
    static_batch_size = tensor_util.constant_value(batch_size)

    # Calculate the current lengths of the predictions
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Calculate the total log probs for the new hypotheses
    # Final Shape: [batch_size, beam_width, vocab_size]
    step_log_probs = nn_ops.log_softmax(logits)
    step_log_probs = _mask_probs(
        step_log_probs, end_token, previously_finished)
    total_probs = tf.expand_dims(
        beam_state.log_probs, axis=2) + step_log_probs

    # Calculate the continuation lengths by adding to all continuing beams.
    vocab_size = logits.shape[-1].value
    lengths_to_add = tf.one_hot(
        indices=tf.tile(
            tf.reshape(end_token, [1, 1]), [batch_size, beam_width]),
        depth=vocab_size,
        on_value=0,
        off_value=1)
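    # For example, with vocab_size=5 and end_token=3 each row of
    # lengths_to_add is [1, 1, 1, 0, 1]: continuing with any non-EOS token
    # extends a hypothesis by one step, while predicting EOS does not.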
    add_mask = (1 - tf.to_int32(previously_finished))
    lengths_to_add = tf.expand_dims(add_mask, 2) * lengths_to_add
    new_prediction_lengths = (
        lengths_to_add + tf.expand_dims(prediction_lengths, 2))

    # Calculate the scores for each beam
    scores = _get_scores(
        log_probs=total_probs,
        sequence_lengths=new_prediction_lengths,
        length_penalty_weight=length_penalty_weight)

    time = ops.convert_to_tensor(time, name="time")
    # During the first time step we only consider the initial beam
    scores_shape = tf.shape(scores)
    scores_flat = tf.cond(
        time > 0,
        lambda: tf.reshape(scores, [batch_size, -1]),
        lambda: scores[:, 0])
    num_available_beam = tf.cond(
        time > 0, lambda: tf.reduce_prod(scores_shape[1:]),
        lambda: tf.reduce_prod(scores_shape[2:]))

    # Pick the next beams according to the specified successors function
    next_beam_size = tf.minimum(
        ops.convert_to_tensor(
            beam_width, dtype=dtypes.int32, name="beam_width"),
        num_available_beam)
    next_beam_scores, word_indices = nn_ops.top_k(
        scores_flat, k=next_beam_size)
    next_beam_scores.set_shape([static_batch_size, beam_width])
    word_indices.set_shape([static_batch_size, beam_width])

    # Pick out the probs, beam_ids, and states according to the chosen
    # predictions
    next_beam_probs = _tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=total_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1])
    next_word_ids = tf.to_int32(word_indices % vocab_size)
    next_beam_ids = tf.to_int32(word_indices / vocab_size)
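    # Example: with beam_width=3 and vocab_size=10, a flat index of 24 into
    # scores_flat decomposes into word id 24 % 10 = 4 chosen from parent
    # beam 24 // 10 = 2.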

    # Append new ids to current predictions
    previously_finished = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=previously_finished,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    next_finished = tf.logical_or(previously_finished,
                                  tf.equal(next_word_ids, end_token))

    # Calculate the length of the next predictions.
    # 1. Finished beams remain unchanged
    # 2. Beams that are now finished (EOS predicted) remain unchanged
    # 3. Beams that are not yet finished have their length increased by 1
    lengths_to_add = tf.to_int32(
        tf.not_equal(next_word_ids, end_token))
    lengths_to_add = (1 - tf.to_int32(next_finished)) * lengths_to_add
    next_prediction_len = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=beam_state.lengths,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    next_prediction_len += lengths_to_add

    # Pick out the cell_states according to the next_beam_ids. We use a
    # different gather_shape here because the cell_state tensors, i.e.
    # the tensors that would be gathered from, all have dimension
    # greater than two and we need to preserve those dimensions.
    next_cell_state = nest.map_structure(
        lambda gather_from: _maybe_tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1]),
        next_cell_state)

    next_state = BeamSearchDecoderState(
        cell_state=next_cell_state,
        log_probs=next_beam_probs,
        lengths=next_prediction_len,
        finished=next_finished)

    output = BeamSearchDecoderOutput(
        scores=next_beam_scores,
        predicted_ids=next_word_ids,
        parent_ids=next_beam_ids)

    return output, next_state
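

# _get_scores and _length_penalty are called by _beam_search_step above but
# were missing from this copy; restored here following the upstream TensorFlow
# implementation (length penalty from https://arxiv.org/abs/1609.08144).
def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
    """Calculates scores for beam search hypotheses."""
    length_penalty_ = _length_penalty(
        sequence_lengths=sequence_lengths,
        penalty_factor=length_penalty_weight)
    return log_probs / length_penalty_


def _length_penalty(sequence_lengths, penalty_factor):
    """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.
    Returns the length penalty tensor:
      [(5 + sequence_lengths) / 6] ** penalty_factor
    where all operations are performed element-wise.
    """
    penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor")
    static_penalty = tensor_util.constant_value(penalty_factor)
    if static_penalty is not None and static_penalty == 0:
        return 1.0
    return tf.div((5. + tf.to_float(sequence_lengths)) ** penalty_factor,
                  (5. + 1.) ** penalty_factor)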


def _mask_probs(probs, eos_token, finished):
    """Masks log probabilities so that finished beams allocate all probability
    mass to `eos_token` while unfinished beams remain unchanged.
    Args:
      probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]`.
      eos_token: An int32 id corresponding to the EOS token.
      finished: A boolean tensor of shape `[batch_size, beam_width]` marking
        beams that are already finished.
    Returns:
      A tensor of shape `[batch_size, beam_width, vocab_size]`.
    """
    vocab_size = tf.shape(probs)[2]
    finished_mask = tf.expand_dims(1. - tf.to_float(finished), axis=2)
    # Unfinished beams keep their log probs unchanged.
    non_finished_examples = finished_mask * probs
    # Finished beams put all probability on EOS (log prob 0) and the dtype's
    # minimum value (effectively -inf) everywhere else.
    finished_row = tf.one_hot(
        eos_token,
        vocab_size,
        dtype=probs.dtype,
        on_value=0.,
        off_value=probs.dtype.min)
    finished_examples = (1. - finished_mask) * finished_row
    return finished_examples + non_finished_examples


def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
    """Maybe applies _tensor_gather_helper.
    This applies _tensor_gather_helper when the gather_from dims is at least as
    big as the length of gather_shape. This is used in conjunction with nest so
    that we don't apply _tensor_gather_helper to inapplicable values like scalars.
    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The batch size.
      range_size: The number of values in each range. Likely equal to beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve the
        correct values. An example is when gather_from is the attention from an
        AttentionWrapperState with shape [batch_size, beam_width, attention_size].
        There, we want to preserve the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.
    Returns:
      output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
        or the original tensor if its dimensions are too small.
    """
    _check_maybe(gather_from)
    if gather_from.shape.ndims >= len(gather_shape):
        return _tensor_gather_helper(
            gather_indices=gather_indices,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=range_size,
            gather_shape=gather_shape)
    else:
        return gather_from


def _tensor_gather_helper(gather_indices, gather_from, batch_size,
                          range_size, gather_shape):
    """Helper for gathering the right indices from the tensor.
    This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
    gathering from that according to the gather_indices, which are offset by
    the right amounts in order to preserve the batch order.
    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The input batch size.
      range_size: The number of values in each range. Likely equal to beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve the
        correct values. An example is when gather_from is the attention from an
        AttentionWrapperState with shape [batch_size, beam_width, attention_size].
        There, we want to preserve the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.
    Returns:
      output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
    """
    range_ = tf.expand_dims(tf.range(batch_size) * range_size, 1)
    gather_indices = tf.reshape(gather_indices + range_, [-1])
    output = tf.gather(
        tf.reshape(gather_from, gather_shape), gather_indices)
    final_shape = tf.shape(gather_from)[:1 + len(gather_shape)]
    static_batch_size = tensor_util.constant_value(batch_size)
    final_static_shape = (tensor_shape.TensorShape([static_batch_size])
                          .concatenate(
                              gather_from.shape[1:1 + len(gather_shape)]))
    output = tf.reshape(output, final_shape)
    output.set_shape(final_static_shape)
    return output
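

# Offset illustration: with batch_size=2 and range_size=beam_width=3, batch
# entry 0 owns flat rows 0..2 and entry 1 owns rows 3..5, so gather_indices
# [[0, 2], [1, 0]] become flat indices [0, 2, 4, 3] after adding the per-batch
# offsets [0, 3]: each entry gathers only from its own beams.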
